"""Configuration class for YOCO Model."""

from dataclasses import dataclass
from typing import Optional

from megatron.core.transformer import TransformerConfig
from megatron.training import get_args
from megatron.training.arguments import core_transformer_config_from_args


@dataclass
class YOCOConfig(TransformerConfig):
    """Configuration object for YOCO Model.

    Extends ``TransformerConfig`` with the YOCO-specific options. Every option
    that is assigned onto the config when it is built from the training
    arguments is declared here as a real dataclass field, so that it shows up
    in ``dataclasses.fields()``, ``repr()``, ``asdict()`` and ``replace()``
    instead of existing only as an ad-hoc instance attribute.
    """

    use_yoco: bool = False
    """Whether to use YOCO model."""

    num_self_attn_layers: Optional[int] = None
    """Value of the ``num_self_attn_layers`` training argument (by its name,
    the number of self-attention layers -- confirm against the model code).
    ``None`` until it is populated from the arguments."""

def get_yoco_config() -> YOCOConfig:
    """Create a ``YOCOConfig`` from the arguments returned by ``get_args()``.

    The config object is first produced by
    ``core_transformer_config_from_args`` with ``YOCOConfig`` as the target
    class; the YOCO-specific options are then copied onto it explicitly from
    the same arguments object.

    Returns:
        The populated ``YOCOConfig`` instance.
    """
    megatron_args = get_args()
    yoco_config: YOCOConfig = core_transformer_config_from_args(
        megatron_args, YOCOConfig
    )

    # Copy the YOCO-specific options over one by one, in a fixed order.
    for option in ("use_yoco", "num_self_attn_layers"):
        setattr(yoco_config, option, getattr(megatron_args, option))

    return yoco_config
